set(0, 'DefaultFigureWindowStyle', 'docked')
clear

%% load data

% Baseline model (file suffix "_sr"): copy the loaded solution into run-1 names.
load('run1a_baseline_sr.mat')
[prob_Nxz1, gesol1, parms1, Phi_Nxz1] = deal(prob_Nxz, gesol, parms, Phi_Nxz);

% Model without slow recoveries (file suffix "_no_sr"): copy into run-2 names.
% This load overwrites the unsuffixed prob_Nxz/gesol/parms/Phi_Nxz from run 1.
load('run1b_baseline_no_sr.mat')
[prob_Nxz2, gesol2, parms2, Phi_Nxz2] = deal(prob_Nxz, gesol, parms, Phi_Nxz);

%% compare recovery in baseline model against model without slow recoveries

% Starting point (t=0): probability-weighted means of N and x under the
% first slice of each run's distribution prob_Nxz*(:,:,1). N varies along
% rows and x along columns of that slice, so the N grid is tiled across
% columns and the (transposed) x grid is tiled down rows.
N01 = sum(prob_Nxz1(:,:,1) .* repmat(parms1.grid_N(:), [1 parms1.Nx]), 'all') ...
    / sum(prob_Nxz1(:,:,1), 'all');
x01 = sum(prob_Nxz1(:,:,1) .* repmat(parms1.grid_x(:).', [parms1.NN 1]), 'all') ...
    / sum(prob_Nxz1(:,:,1), 'all');
N02 = sum(prob_Nxz2(:,:,1) .* repmat(parms2.grid_N(:), [1 parms2.Nx]), 'all') ...
    / sum(prob_Nxz2(:,:,1), 'all');
x02 = sum(prob_Nxz2(:,:,1) .* repmat(parms2.grid_x(:).', [parms2.NN 1]), 'all') ...
    / sum(prob_Nxz2(:,:,1), 'all');

% Simulation horizon: T periods after t=0 (5*12, i.e. five years if a period is a month).
T = 5*12;
ts = 0:T;

% State paths: entry 1 holds the t=0 value, the remaining T entries are
% filled in by the simulation loop.
Nt1 = [N01; zeros(T,1)];
xt1 = [x01; zeros(T,1)];
Nt2 = [N02; zeros(T,1)];
xt2 = [x02; zeros(T,1)];

% Asset-pricing paths: undefined at t=0 (NaN), filled from t=1 onward.
EntropyComponent1     = [NaN; zeros(T,1)];
EntropyComponent2     = [NaN; zeros(T,1)];
BondRiskPrem1         = [NaN; zeros(T,1)];
BondRiskPrem2         = [NaN; zeros(T,1)];
PredictableComponent1 = [NaN; zeros(T,1)];
PredictableComponent2 = [NaN; zeros(T,1)];
fwdspread1            = [NaN; zeros(T,1)];
fwdspread2            = [NaN; zeros(T,1)];
EdlogM1               = [NaN; zeros(T,1)];
EdlogM2               = [NaN; zeros(T,1)];

% Interpolation grids. meshgrid(grid_N, grid_x) returns arrays with one row per
% grid_x point and one column per grid_N point, so the solution arrays indexed
% (N, x, z) are transposed (') before every interp2 call below.
[NN1,xx1] = meshgrid(parms1.grid_N,parms1.grid_x);
[NN2,xx2] = meshgrid(parms2.grid_N,parms2.grid_x);

% shock occurs at t=1, state switches to 2
spath = zeros(T+1,1); spath(1) = 1; spath(2:end) = 2;

% Mean and std of next-period z conditional on each current state (one row of
% StateTransitionProbs per current state); used below to standardize the
% realized z into the innovation that drives x.
condmean1 = parms1.StateTransitionProbs*parms1.grid_z(:);
condmean2 = parms2.StateTransitionProbs*parms2.grid_z(:);
condstd1 = sqrt(parms1.StateTransitionProbs*(parms1.grid_z(:).^2) - condmean1.^2);
condstd2 = sqrt(parms2.StateTransitionProbs*(parms2.grid_z(:).^2) - condmean2.^2);

theta_t1 = zeros(T+1,1); theta_t1(1) = NaN;
theta_t2 = zeros(T+1,1); theta_t2(1) = NaN;

% E_dlogM = zeros(NN,Nx,Nz,NT); % E_t[log(M(t+T)) - log(M(t))]
entropy_ny = 2; % increment in years

% Horizon indices into gesol*.entropy (horizons in months: 12 = 1 year).
% They are looked up separately for each run. The two solutions need not share
% the same horizon grid, so run-1 indices must not be reused on run 2.
idx_1y = find(gesol1.entropy.horizon==12);
idx_entropy_ny = find(gesol1.entropy.horizon==entropy_ny*12);
idx_entropy_nless1y = find(gesol1.entropy.horizon==entropy_ny*12-12);
idx_1y_2 = find(gesol2.entropy.horizon==12);
idx_entropy_ny_2 = find(gesol2.entropy.horizon==entropy_ny*12);
idx_entropy_nless1y_2 = find(gesol2.entropy.horizon==entropy_ny*12-12);
% Fail early with a readable message rather than inside interp2/indexing.
assert(isscalar(idx_1y) && isscalar(idx_entropy_ny) && isscalar(idx_entropy_nless1y), ...
    'Run 1: entropy horizons 12, %d and %d must each appear exactly once.', ...
    entropy_ny*12-12, entropy_ny*12);
assert(isscalar(idx_1y_2) && isscalar(idx_entropy_ny_2) && isscalar(idx_entropy_nless1y_2), ...
    'Run 2: entropy horizons 12, %d and %d must each appear exactly once.', ...
    entropy_ny*12-12, entropy_ny*12);

% precompute E_t[log M_{t+n-1,t+n}] (n = entropy_ny years) by pushing the
% one-year expectation forward (n-1)*12 periods with the transition matrix
EdlogM1fwd = reshape((Phi_Nxz1^((entropy_ny-1)*12))*reshape(gesol1.entropy.E_dlogM(:,:,:,idx_1y),[],1),[parms1.NN parms1.Nx parms1.Nz]);
EdlogM2fwd = reshape((Phi_Nxz2^((entropy_ny-1)*12))*reshape(gesol2.entropy.E_dlogM(:,:,:,idx_1y_2),[],1),[parms2.NN parms2.Nx parms2.Nz]);

% expected excess holding-period log return on an entropy_ny-year bond, per grid point
bondrp1_Nxz = calc_rp(gesol1,Phi_Nxz1,entropy_ny*12);
bondrp2_Nxz = calc_rp(gesol2,Phi_Nxz2,entropy_ny*12);

theta_t1(1) = interp2(NN1,xx1,gesol1.theta(:,:,spath(1))',Nt1(1),xt1(1));
theta_t2(1) = interp2(NN2,xx2,gesol2.theta(:,:,spath(1))',Nt2(1),xt2(1));

for t = 2:T+1
    % --- run 1 (baseline) ---
    % N evolves as (1-s)*N + f*(1-N), with f evaluated at last period's state
    ft1 = interp2(NN1,xx1,gesol1.f(:,:,spath(t-1))',Nt1(t-1),xt1(t-1));
    Nt1(t) = (1-parms1.s(1))*Nt1(t-1) + ft1*(1 - Nt1(t-1));
    % AR(1) in x driven by the standardized z innovation
    xt1(t) = (1-parms1.rho_x)*parms1.x_bar + parms1.rho_x*xt1(t-1) ...
        + parms1.sigma_x(1)*(parms1.grid_z(spath(t)) - condmean1(spath(t-1)))/condstd1(spath(t-1));
    EdlogM1(t) = interp2(NN1,xx1,gesol1.entropy.E_dlogM(:,:,spath(t),idx_1y)',Nt1(t),xt1(t));
    EntropyComponent1(t) = interp2(NN1,xx1,gesol1.entropy.CondEntropy(:,:,spath(t),idx_1y)',Nt1(t),xt1(t))...
        + interp2(NN1,xx1,gesol1.entropy.CondEntropy(:,:,spath(t),idx_entropy_nless1y)',Nt1(t),xt1(t))...
        - interp2(NN1,xx1,gesol1.entropy.CondEntropy(:,:,spath(t),idx_entropy_ny)',Nt1(t),xt1(t));
    PredictableComponent1(t) = EdlogM1(t) - interp2(NN1,xx1,EdlogM1fwd(:,:,spath(t))',Nt1(t),xt1(t));
    BondRiskPrem1(t) = interp2(NN1,xx1,bondrp1_Nxz(:,:,spath(t))',Nt1(t),xt1(t));
    theta_t1(t) = interp2(NN1,xx1,gesol1.theta(:,:,spath(t))',Nt1(t),xt1(t));
    % NOTE(review): fwdspread1/fwdspread2 stay at their initial values. The
    % surface fwdspread1_Nxz used below is not computed anywhere in this script.
%     fwdspread1(t) = interp2(NN1,xx1,fwdspread1_Nxz(:,:,spath(t))',Nt1(t),xt1(t));

    % --- run 2 (no slow recovery): same recursions, run-2 indices ---
    ft2 = interp2(NN2,xx2,gesol2.f(:,:,spath(t-1))',Nt2(t-1),xt2(t-1));
    Nt2(t) = (1-parms2.s(1))*Nt2(t-1) + ft2*(1 - Nt2(t-1));
    xt2(t) = (1-parms2.rho_x)*parms2.x_bar + parms2.rho_x*xt2(t-1) ...
        + parms2.sigma_x(1)*(parms2.grid_z(spath(t)) - condmean2(spath(t-1)))/condstd2(spath(t-1));
    EdlogM2(t) = interp2(NN2,xx2,gesol2.entropy.E_dlogM(:,:,spath(t),idx_1y_2)',Nt2(t),xt2(t));
    EntropyComponent2(t) = interp2(NN2,xx2,gesol2.entropy.CondEntropy(:,:,spath(t),idx_1y_2)',Nt2(t),xt2(t))...
        + interp2(NN2,xx2,gesol2.entropy.CondEntropy(:,:,spath(t),idx_entropy_nless1y_2)',Nt2(t),xt2(t))...
        - interp2(NN2,xx2,gesol2.entropy.CondEntropy(:,:,spath(t),idx_entropy_ny_2)',Nt2(t),xt2(t));
    PredictableComponent2(t) = EdlogM2(t) - interp2(NN2,xx2,EdlogM2fwd(:,:,spath(t))',Nt2(t),xt2(t));
    BondRiskPrem2(t) = interp2(NN2,xx2,bondrp2_Nxz(:,:,spath(t))',Nt2(t),xt2(t));
    theta_t2(t) = interp2(NN2,xx2,gesol2.theta(:,:,spath(t))',Nt2(t),xt2(t));
end

%% plot, ver 1

FontSize = 16;
LineWidth = 2;

figure

% Panel A: employment relative to its t=0 value (100 x log difference),
% plotted from t=1 onward
subplot(1,2,1)
hold('on')
plot(ts(2:end), 100*(log(Nt1(2:end)) - log(Nt1(1))), '-b', 'LineWidth', LineWidth)
plot(ts(2:end), 100*(log(Nt2(2:end)) - log(Nt2(1))), '-.r', 'LineWidth', LineWidth)
title('A. Employment growth', 'Interpreter', 'latex', 'FontSize', FontSize)
xlabel('$t$', 'Interpreter', 'latex', 'FontSize', FontSize)
ylabel('$\log N_t - \log N_0$ (\%)', 'Interpreter', 'latex', 'FontSize', FontSize)
legend({'Baseline','No slow recovery'}, 'Interpreter', 'latex', ...
    'FontSize', FontSize, 'Location', 'southeast')
legend('boxoff')
box('on')

% Panel B: one-year expected log SDF along the simulated path
subplot(1,2,2)
hold('on')
plot(ts(2:end), EdlogM1(2:end), '-b', 'LineWidth', LineWidth)
plot(ts(2:end), EdlogM2(2:end), '-.r', 'LineWidth', LineWidth)
title('B. Expected value of log SDF', 'Interpreter', 'latex', 'FontSize', FontSize)
xlabel('$t$', 'Interpreter', 'latex', 'FontSize', FontSize)
ylabel('$E_t[\log M_{t,t+12}]$', 'Interpreter', 'latex', 'FontSize', FontSize)
box('on')

%% function to compute bond risk premia

function rp = calc_rp(gesol,Phi_Nxz,mat,hp)
% CALC_RP  Expected log excess holding-period return on a real bond, per grid point.
%
%   rp = calc_rp(gesol, Phi_Nxz, mat) returns
%       E_t[p^(mat-12)_{t+12}] - p^(mat)_t + p^(12)_t
%   where p^(m) = log(gesol.real_bonds.prices(:,:,:,k)) with
%   gesol.real_bonds.maturities(k) == m. That is, buy the maturity-mat bond,
%   sell it 12 periods later as a maturity-(mat-12) bond, and subtract the
%   log return -p^(12) of holding a 12-period bond to maturity.
%
%   rp = calc_rp(gesol, Phi_Nxz, mat, hp) uses a holding period of hp
%   periods instead of 12 (hp defaults to 12, the original behavior).
%
%   Inputs
%     gesol    struct with fields real_bonds.maturities and real_bonds.prices;
%              prices is indexed as prices(:,:,:,k) for maturity index k
%     Phi_Nxz  square matrix applied to the stacked grid p(:). Assumed to be
%              the one-period transition matrix, so Phi_Nxz^hp*v(:) gives
%              E_t[v_{t+hp}] (TODO confirm)
%     mat      maturity of the bond bought (same units as maturities)
%     hp       optional holding period (default 12)
%
%   Output
%     rp       array with the size of prices(:,:,:,k)
%
%   Raises an error with identifier 'calc_rp:badMaturity' if mat, mat-hp or
%   hp does not appear exactly once in gesol.real_bonds.maturities.

    if nargin < 4 || isempty(hp)
        hp = 12;
    end

    maturities = gesol.real_bonds.maturities;
    idx_buy  = find_maturity(maturities, mat);
    idx_sell = find_maturity(maturities, mat - hp);
    idx_hp   = find_maturity(maturities, hp);

    p_buy  = log(gesol.real_bonds.prices(:,:,:,idx_buy));
    p_sell = log(gesol.real_bonds.prices(:,:,:,idx_sell));
    p_hp   = log(gesol.real_bonds.prices(:,:,:,idx_hp));

    % expected log price, hp periods ahead, of the bond then having mat-hp left
    Ep_sell = reshape((Phi_Nxz^hp)*p_sell(:), size(p_buy));

    rp = Ep_sell - p_buy + p_hp;

end

function idx = find_maturity(maturities, m)
% FIND_MATURITY  Index of maturity m in maturities; errors unless it occurs exactly once.
    idx = find(maturities == m);
    if ~isscalar(idx)
        error('calc_rp:badMaturity', ...
            'Maturity %g must appear exactly once in gesol.real_bonds.maturities (found %d).', ...
            m, numel(idx));
    end
end